1   /*
2    * Copyright (C) 2007 The Guava Authors
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    * http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package com.google.common.collect;
18  
19  import static com.google.common.base.Preconditions.checkArgument;
20  import static com.google.common.base.Preconditions.checkState;
21  import static com.google.common.collect.CollectPreconditions.checkNonnegative;
22  import static com.google.common.collect.CollectPreconditions.checkRemove;
23  
24  import com.google.common.annotations.GwtCompatible;
25  import com.google.common.base.MoreObjects;
26  import com.google.common.primitives.Ints;
27  
28  import java.io.Serializable;
29  import java.util.Comparator;
30  import java.util.ConcurrentModificationException;
31  import java.util.Iterator;
32  import java.util.NoSuchElementException;
33  
34  import javax.annotation.Nullable;
35  
36  /**
37   * A multiset which maintains the ordering of its elements, according to either their natural order
38   * or an explicit {@link Comparator}. In all cases, this implementation uses
39   * {@link Comparable#compareTo} or {@link Comparator#compare} instead of {@link Object#equals} to
40   * determine equivalence of instances.
41   *
42   * <p><b>Warning:</b> The comparison must be <i>consistent with equals</i> as explained by the
43   * {@link Comparable} class specification. Otherwise, the resulting multiset will violate the
44   * {@link java.util.Collection} contract, which is specified in terms of {@link Object#equals}.
45   *
 * <p>See the Guava User Guide article on <a href=
 * "https://github.com/google/guava/wiki/NewCollectionTypesExplained#multiset">
 * {@code Multiset}</a>.
49   *
50   * @author Louis Wasserman
51   * @author Jared Levy
52   * @since 2.0 (imported from Google Collections Library)
53   */
54  @GwtCompatible(emulated = true)
55  public final class TreeMultiset<E> extends AbstractSortedMultiset<E> implements Serializable {
56  
57    /**
58     * Creates a new, empty multiset, sorted according to the elements' natural order. All elements
59     * inserted into the multiset must implement the {@code Comparable} interface. Furthermore, all
60     * such elements must be <i>mutually comparable</i>: {@code e1.compareTo(e2)} must not throw a
61     * {@code ClassCastException} for any elements {@code e1} and {@code e2} in the multiset. If the
62     * user attempts to add an element to the multiset that violates this constraint (for example,
63     * the user attempts to add a string element to a set whose elements are integers), the
64     * {@code add(Object)} call will throw a {@code ClassCastException}.
65     *
66     * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
67     * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
68     */
69    public static <E extends Comparable> TreeMultiset<E> create() {
70      return new TreeMultiset<E>(Ordering.natural());
71    }
72  
73    /**
74     * Creates a new, empty multiset, sorted according to the specified comparator. All elements
75     * inserted into the multiset must be <i>mutually comparable</i> by the specified comparator:
76     * {@code comparator.compare(e1,
77     * e2)} must not throw a {@code ClassCastException} for any elements {@code e1} and {@code e2} in
78     * the multiset. If the user attempts to add an element to the multiset that violates this
79     * constraint, the {@code add(Object)} call will throw a {@code ClassCastException}.
80     *
81     * @param comparator
82     *          the comparator that will be used to sort this multiset. A null value indicates that
83     *          the elements' <i>natural ordering</i> should be used.
84     */
85    @SuppressWarnings("unchecked")
86    public static <E> TreeMultiset<E> create(@Nullable Comparator<? super E> comparator) {
87      return (comparator == null)
88          ? new TreeMultiset<E>((Comparator) Ordering.natural())
89          : new TreeMultiset<E>(comparator);
90    }
91  
92    /**
93     * Creates an empty multiset containing the given initial elements, sorted according to the
94     * elements' natural order.
95     *
96     * <p>This implementation is highly efficient when {@code elements} is itself a {@link Multiset}.
97     *
98     * <p>The type specification is {@code <E extends Comparable>}, instead of the more specific
99     * {@code <E extends Comparable<? super E>>}, to support classes defined without generics.
100    */
101   public static <E extends Comparable> TreeMultiset<E> create(Iterable<? extends E> elements) {
102     TreeMultiset<E> multiset = create();
103     Iterables.addAll(multiset, elements);
104     return multiset;
105   }
106 
107   private final transient Reference<AvlNode<E>> rootReference;
108   private final transient GeneralRange<E> range;
109   private final transient AvlNode<E> header;
110 
111   TreeMultiset(Reference<AvlNode<E>> rootReference, GeneralRange<E> range, AvlNode<E> endLink) {
112     super(range.comparator());
113     this.rootReference = rootReference;
114     this.range = range;
115     this.header = endLink;
116   }
117 
118   TreeMultiset(Comparator<? super E> comparator) {
119     super(comparator);
120     this.range = GeneralRange.all(comparator);
121     this.header = new AvlNode<E>(null, 1);
122     successor(header, header);
123     this.rootReference = new Reference<AvlNode<E>>();
124   }
125 
126   /**
127    * A function which can be summed across a subtree.
128    */
129   private enum Aggregate {
130     SIZE {
131       @Override
132       int nodeAggregate(AvlNode<?> node) {
133         return node.elemCount;
134       }
135 
136       @Override
137       long treeAggregate(@Nullable AvlNode<?> root) {
138         return (root == null) ? 0 : root.totalCount;
139       }
140     },
141     DISTINCT {
142       @Override
143       int nodeAggregate(AvlNode<?> node) {
144         return 1;
145       }
146 
147       @Override
148       long treeAggregate(@Nullable AvlNode<?> root) {
149         return (root == null) ? 0 : root.distinctElements;
150       }
151     };
152     abstract int nodeAggregate(AvlNode<?> node);
153 
154     abstract long treeAggregate(@Nullable AvlNode<?> root);
155   }
156 
157   private long aggregateForEntries(Aggregate aggr) {
158     AvlNode<E> root = rootReference.get();
159     long total = aggr.treeAggregate(root);
160     if (range.hasLowerBound()) {
161       total -= aggregateBelowRange(aggr, root);
162     }
163     if (range.hasUpperBound()) {
164       total -= aggregateAboveRange(aggr, root);
165     }
166     return total;
167   }
168 
169   private long aggregateBelowRange(Aggregate aggr, @Nullable AvlNode<E> node) {
170     if (node == null) {
171       return 0;
172     }
173     int cmp = comparator().compare(range.getLowerEndpoint(), node.elem);
174     if (cmp < 0) {
175       return aggregateBelowRange(aggr, node.left);
176     } else if (cmp == 0) {
177       switch (range.getLowerBoundType()) {
178         case OPEN:
179           return aggr.nodeAggregate(node) + aggr.treeAggregate(node.left);
180         case CLOSED:
181           return aggr.treeAggregate(node.left);
182         default:
183           throw new AssertionError();
184       }
185     } else {
186       return aggr.treeAggregate(node.left) + aggr.nodeAggregate(node)
187           + aggregateBelowRange(aggr, node.right);
188     }
189   }
190 
191   private long aggregateAboveRange(Aggregate aggr, @Nullable AvlNode<E> node) {
192     if (node == null) {
193       return 0;
194     }
195     int cmp = comparator().compare(range.getUpperEndpoint(), node.elem);
196     if (cmp > 0) {
197       return aggregateAboveRange(aggr, node.right);
198     } else if (cmp == 0) {
199       switch (range.getUpperBoundType()) {
200         case OPEN:
201           return aggr.nodeAggregate(node) + aggr.treeAggregate(node.right);
202         case CLOSED:
203           return aggr.treeAggregate(node.right);
204         default:
205           throw new AssertionError();
206       }
207     } else {
208       return aggr.treeAggregate(node.right) + aggr.nodeAggregate(node)
209           + aggregateAboveRange(aggr, node.left);
210     }
211   }
212 
213   @Override
214   public int size() {
215     return Ints.saturatedCast(aggregateForEntries(Aggregate.SIZE));
216   }
217 
218   @Override
219   int distinctElements() {
220     return Ints.saturatedCast(aggregateForEntries(Aggregate.DISTINCT));
221   }
222 
223   @Override
224   public int count(@Nullable Object element) {
225     try {
226       @SuppressWarnings("unchecked")
227       E e = (E) element;
228       AvlNode<E> root = rootReference.get();
229       if (!range.contains(e) || root == null) {
230         return 0;
231       }
232       return root.count(comparator(), e);
233     } catch (ClassCastException e) {
234       return 0;
235     } catch (NullPointerException e) {
236       return 0;
237     }
238   }
239 
240   @Override
241   public int add(@Nullable E element, int occurrences) {
242     checkNonnegative(occurrences, "occurrences");
243     if (occurrences == 0) {
244       return count(element);
245     }
246     checkArgument(range.contains(element));
247     AvlNode<E> root = rootReference.get();
248     if (root == null) {
249       comparator().compare(element, element);
250       AvlNode<E> newRoot = new AvlNode<E>(element, occurrences);
251       successor(header, newRoot, header);
252       rootReference.checkAndSet(root, newRoot);
253       return 0;
254     }
255     int[] result = new int[1]; // used as a mutable int reference to hold result
256     AvlNode<E> newRoot = root.add(comparator(), element, occurrences, result);
257     rootReference.checkAndSet(root, newRoot);
258     return result[0];
259   }
260 
261   @Override
262   public int remove(@Nullable Object element, int occurrences) {
263     checkNonnegative(occurrences, "occurrences");
264     if (occurrences == 0) {
265       return count(element);
266     }
267     AvlNode<E> root = rootReference.get();
268     int[] result = new int[1]; // used as a mutable int reference to hold result
269     AvlNode<E> newRoot;
270     try {
271       @SuppressWarnings("unchecked")
272       E e = (E) element;
273       if (!range.contains(e) || root == null) {
274         return 0;
275       }
276       newRoot = root.remove(comparator(), e, occurrences, result);
277     } catch (ClassCastException e) {
278       return 0;
279     } catch (NullPointerException e) {
280       return 0;
281     }
282     rootReference.checkAndSet(root, newRoot);
283     return result[0];
284   }
285 
286   @Override
287   public int setCount(@Nullable E element, int count) {
288     checkNonnegative(count, "count");
289     if (!range.contains(element)) {
290       checkArgument(count == 0);
291       return 0;
292     }
293 
294     AvlNode<E> root = rootReference.get();
295     if (root == null) {
296       if (count > 0) {
297         add(element, count);
298       }
299       return 0;
300     }
301     int[] result = new int[1]; // used as a mutable int reference to hold result
302     AvlNode<E> newRoot = root.setCount(comparator(), element, count, result);
303     rootReference.checkAndSet(root, newRoot);
304     return result[0];
305   }
306 
307   @Override
308   public boolean setCount(@Nullable E element, int oldCount, int newCount) {
309     checkNonnegative(newCount, "newCount");
310     checkNonnegative(oldCount, "oldCount");
311     checkArgument(range.contains(element));
312 
313     AvlNode<E> root = rootReference.get();
314     if (root == null) {
315       if (oldCount == 0) {
316         if (newCount > 0) {
317           add(element, newCount);
318         }
319         return true;
320       } else {
321         return false;
322       }
323     }
324     int[] result = new int[1]; // used as a mutable int reference to hold result
325     AvlNode<E> newRoot = root.setCount(comparator(), element, oldCount, newCount, result);
326     rootReference.checkAndSet(root, newRoot);
327     return result[0] == oldCount;
328   }
329 
330   private Entry<E> wrapEntry(final AvlNode<E> baseEntry) {
331     return new Multisets.AbstractEntry<E>() {
332       @Override
333       public E getElement() {
334         return baseEntry.getElement();
335       }
336 
337       @Override
338       public int getCount() {
339         int result = baseEntry.getCount();
340         if (result == 0) {
341           return count(getElement());
342         } else {
343           return result;
344         }
345       }
346     };
347   }
348 
349   /**
350    * Returns the first node in the tree that is in range.
351    */
352   @Nullable private AvlNode<E> firstNode() {
353     AvlNode<E> root = rootReference.get();
354     if (root == null) {
355       return null;
356     }
357     AvlNode<E> node;
358     if (range.hasLowerBound()) {
359       E endpoint = range.getLowerEndpoint();
360       node = rootReference.get().ceiling(comparator(), endpoint);
361       if (node == null) {
362         return null;
363       }
364       if (range.getLowerBoundType() == BoundType.OPEN
365           && comparator().compare(endpoint, node.getElement()) == 0) {
366         node = node.succ;
367       }
368     } else {
369       node = header.succ;
370     }
371     return (node == header || !range.contains(node.getElement())) ? null : node;
372   }
373 
374   @Nullable private AvlNode<E> lastNode() {
375     AvlNode<E> root = rootReference.get();
376     if (root == null) {
377       return null;
378     }
379     AvlNode<E> node;
380     if (range.hasUpperBound()) {
381       E endpoint = range.getUpperEndpoint();
382       node = rootReference.get().floor(comparator(), endpoint);
383       if (node == null) {
384         return null;
385       }
386       if (range.getUpperBoundType() == BoundType.OPEN
387           && comparator().compare(endpoint, node.getElement()) == 0) {
388         node = node.pred;
389       }
390     } else {
391       node = header.pred;
392     }
393     return (node == header || !range.contains(node.getElement())) ? null : node;
394   }
395 
396   @Override
397   Iterator<Entry<E>> entryIterator() {
398     return new Iterator<Entry<E>>() {
399       AvlNode<E> current = firstNode();
400       Entry<E> prevEntry;
401 
402       @Override
403       public boolean hasNext() {
404         if (current == null) {
405           return false;
406         } else if (range.tooHigh(current.getElement())) {
407           current = null;
408           return false;
409         } else {
410           return true;
411         }
412       }
413 
414       @Override
415       public Entry<E> next() {
416         if (!hasNext()) {
417           throw new NoSuchElementException();
418         }
419         Entry<E> result = wrapEntry(current);
420         prevEntry = result;
421         if (current.succ == header) {
422           current = null;
423         } else {
424           current = current.succ;
425         }
426         return result;
427       }
428 
429       @Override
430       public void remove() {
431         checkRemove(prevEntry != null);
432         setCount(prevEntry.getElement(), 0);
433         prevEntry = null;
434       }
435     };
436   }
437 
438   @Override
439   Iterator<Entry<E>> descendingEntryIterator() {
440     return new Iterator<Entry<E>>() {
441       AvlNode<E> current = lastNode();
442       Entry<E> prevEntry = null;
443 
444       @Override
445       public boolean hasNext() {
446         if (current == null) {
447           return false;
448         } else if (range.tooLow(current.getElement())) {
449           current = null;
450           return false;
451         } else {
452           return true;
453         }
454       }
455 
456       @Override
457       public Entry<E> next() {
458         if (!hasNext()) {
459           throw new NoSuchElementException();
460         }
461         Entry<E> result = wrapEntry(current);
462         prevEntry = result;
463         if (current.pred == header) {
464           current = null;
465         } else {
466           current = current.pred;
467         }
468         return result;
469       }
470 
471       @Override
472       public void remove() {
473         checkRemove(prevEntry != null);
474         setCount(prevEntry.getElement(), 0);
475         prevEntry = null;
476       }
477     };
478   }
479 
480   @Override
481   public SortedMultiset<E> headMultiset(@Nullable E upperBound, BoundType boundType) {
482     return new TreeMultiset<E>(rootReference, range.intersect(GeneralRange.upTo(
483         comparator(),
484         upperBound,
485         boundType)), header);
486   }
487 
488   @Override
489   public SortedMultiset<E> tailMultiset(@Nullable E lowerBound, BoundType boundType) {
490     return new TreeMultiset<E>(rootReference, range.intersect(GeneralRange.downTo(
491         comparator(),
492         lowerBound,
493         boundType)), header);
494   }
495 
496   static int distinctElements(@Nullable AvlNode<?> node) {
497     return (node == null) ? 0 : node.distinctElements;
498   }
499 
500   private static final class Reference<T> {
501     @Nullable private T value;
502 
503     @Nullable public T get() {
504       return value;
505     }
506 
507     public void checkAndSet(@Nullable T expected, T newValue) {
508       if (value != expected) {
509         throw new ConcurrentModificationException();
510       }
511       value = newValue;
512     }
513   }
514 
515   private static final class AvlNode<E> extends Multisets.AbstractEntry<E> {
516     @Nullable private final E elem;
517 
518     // elemCount is 0 iff this node has been deleted.
519     private int elemCount;
520 
521     private int distinctElements;
522     private long totalCount;
523     private int height;
524     private AvlNode<E> left;
525     private AvlNode<E> right;
526     private AvlNode<E> pred;
527     private AvlNode<E> succ;
528 
529     AvlNode(@Nullable E elem, int elemCount) {
530       checkArgument(elemCount > 0);
531       this.elem = elem;
532       this.elemCount = elemCount;
533       this.totalCount = elemCount;
534       this.distinctElements = 1;
535       this.height = 1;
536       this.left = null;
537       this.right = null;
538     }
539 
540     public int count(Comparator<? super E> comparator, E e) {
541       int cmp = comparator.compare(e, elem);
542       if (cmp < 0) {
543         return (left == null) ? 0 : left.count(comparator, e);
544       } else if (cmp > 0) {
545         return (right == null) ? 0 : right.count(comparator, e);
546       } else {
547         return elemCount;
548       }
549     }
550 
551     private AvlNode<E> addRightChild(E e, int count) {
552       right = new AvlNode<E>(e, count);
553       successor(this, right, succ);
554       height = Math.max(2, height);
555       distinctElements++;
556       totalCount += count;
557       return this;
558     }
559 
560     private AvlNode<E> addLeftChild(E e, int count) {
561       left = new AvlNode<E>(e, count);
562       successor(pred, left, this);
563       height = Math.max(2, height);
564       distinctElements++;
565       totalCount += count;
566       return this;
567     }
568 
    /**
     * Adds {@code count} occurrences of {@code e} to the subtree rooted at this node.
     *
     * @param comparator the ordering used to locate {@code e}
     * @param e the element to add
     * @param count the number of occurrences to add
     * @param result single-element out-parameter; {@code result[0]} is set to the count of
     *     {@code e} before the addition (0 if it was absent)
     * @return the root of this subtree after the addition, which is this node or the result of
     *     rebalancing it
     * @throws IllegalArgumentException if the element's count would exceed
     *     {@code Integer.MAX_VALUE}
     */
    AvlNode<E> add(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
      /*
       * It speeds things up considerably to unconditionally add count to totalCount here,
       * but that destroys failure atomicity in the case of count overflow. =(
       */
      int cmp = comparator.compare(e, elem);
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          // No left subtree: e is absent, so it becomes a new left leaf.
          result[0] = 0;
          return addLeftChild(e, count);
        }
        int initHeight = initLeft.height;

        left = initLeft.add(comparator, e, count, result);
        // A previous count of 0 means the recursive call created a new node.
        if (result[0] == 0) {
          distinctElements++;
        }
        // Totals are updated only after the recursive call has returned without throwing.
        this.totalCount += count;
        // Rebalance only if the height of the left subtree changed.
        return (left.height == initHeight) ? this : rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          // No right subtree: e is absent, so it becomes a new right leaf.
          result[0] = 0;
          return addRightChild(e, count);
        }
        int initHeight = initRight.height;

        right = initRight.add(comparator, e, count, result);
        // A previous count of 0 means the recursive call created a new node.
        if (result[0] == 0) {
          distinctElements++;
        }
        // Totals are updated only after the recursive call has returned without throwing.
        this.totalCount += count;
        // Rebalance only if the height of the right subtree changed.
        return (right.height == initHeight) ? this : rebalance();
      }

      // adding count to me!  No rebalance possible.
      result[0] = elemCount;
      // Computed as a long so that an int overflow is detected before any field is modified.
      long resultCount = (long) elemCount + count;
      checkArgument(resultCount <= Integer.MAX_VALUE);
      this.elemCount += count;
      this.totalCount += count;
      return this;
    }
613 
    /**
     * Removes up to {@code count} occurrences of {@code e} from the subtree rooted at this node.
     * If {@code count} is at least the element's current count, its node is deleted.
     *
     * @param comparator the ordering used to locate {@code e}
     * @param e the element to remove
     * @param count the number of occurrences to remove
     * @param result single-element out-parameter; {@code result[0]} is set to the count of
     *     {@code e} before the removal (0 if it was absent)
     * @return the root of this subtree after the removal; this is whatever {@code deleteMe()}
     *     returns when this node itself is deleted, which can be {@code null}
     */
    AvlNode<E> remove(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
      int cmp = comparator.compare(e, elem);
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          // e would be in the left subtree, which is empty: nothing to remove.
          result[0] = 0;
          return this;
        }

        left = initLeft.remove(comparator, e, count, result);

        if (result[0] > 0) {
          if (count >= result[0]) {
            // Every occurrence was removed, so the node for e is gone from this subtree.
            this.distinctElements--;
            this.totalCount -= result[0];
          } else {
            this.totalCount -= count;
          }
        }
        // Nothing changed if e was absent; otherwise rebalance.
        return (result[0] == 0) ? this : rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          // e would be in the right subtree, which is empty: nothing to remove.
          result[0] = 0;
          return this;
        }

        right = initRight.remove(comparator, e, count, result);

        if (result[0] > 0) {
          if (count >= result[0]) {
            // Every occurrence was removed, so the node for e is gone from this subtree.
            this.distinctElements--;
            this.totalCount -= result[0];
          } else {
            this.totalCount -= count;
          }
        }
        // Unlike the left branch, this branch calls rebalance() even when e was absent.
        return rebalance();
      }

      // removing count from me!
      result[0] = elemCount;
      if (count >= elemCount) {
        return deleteMe();
      } else {
        this.elemCount -= count;
        this.totalCount -= count;
        return this;
      }
    }
664 
    /**
     * Sets the count of {@code e} in the subtree rooted at this node to {@code count}. A count of
     * 0 deletes the element's node; a positive count for an absent element creates a node.
     *
     * @param comparator the ordering used to locate {@code e}
     * @param e the element whose count is being set
     * @param count the new count
     * @param result single-element out-parameter; {@code result[0]} is set to the count of
     *     {@code e} before the call (0 if it was absent)
     * @return the root of this subtree after the update; this is whatever {@code deleteMe()}
     *     returns when this node itself is deleted, which can be {@code null}
     */
    AvlNode<E> setCount(Comparator<? super E> comparator, @Nullable E e, int count, int[] result) {
      int cmp = comparator.compare(e, elem);
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          // e is absent: create a left leaf for a positive count, otherwise nothing to do.
          result[0] = 0;
          return (count > 0) ? addLeftChild(e, count) : this;
        }

        left = initLeft.setCount(comparator, e, count, result);

        // Track whether a node disappeared (count set to 0) or appeared (was absent before).
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        // Apply the difference between the new and the old count to this subtree's total.
        this.totalCount += count - result[0];
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          // e is absent: create a right leaf for a positive count, otherwise nothing to do.
          result[0] = 0;
          return (count > 0) ? addRightChild(e, count) : this;
        }

        right = initRight.setCount(comparator, e, count, result);

        // Track whether a node disappeared (count set to 0) or appeared (was absent before).
        if (count == 0 && result[0] != 0) {
          this.distinctElements--;
        } else if (count > 0 && result[0] == 0) {
          this.distinctElements++;
        }

        // Apply the difference between the new and the old count to this subtree's total.
        this.totalCount += count - result[0];
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (count == 0) {
        return deleteMe();
      }
      this.totalCount += count - elemCount;
      this.elemCount = count;
      return this;
    }
712 
    /**
     * Conditionally sets the count of {@code e} in the subtree rooted at this node: the count is
     * changed to {@code newCount} only if it currently equals {@code expectedCount}.
     *
     * @param comparator the ordering used to locate {@code e}
     * @param e the element whose count is being set
     * @param expectedCount the count the element must currently have for the update to happen
     * @param newCount the count to set when the expectation holds
     * @param result single-element out-parameter; {@code result[0]} is set to the count of
     *     {@code e} before the call (0 if it was absent), whether or not the update happened
     * @return the root of this subtree after the call; this is whatever {@code deleteMe()}
     *     returns when this node itself is deleted, which can be {@code null}
     */
    AvlNode<E> setCount(
        Comparator<? super E> comparator,
        @Nullable E e,
        int expectedCount,
        int newCount,
        int[] result) {
      int cmp = comparator.compare(e, elem);
      if (cmp < 0) {
        AvlNode<E> initLeft = left;
        if (initLeft == null) {
          // e is absent, so its current count is 0; only an expectation of 0 can match.
          result[0] = 0;
          if (expectedCount == 0 && newCount > 0) {
            return addLeftChild(e, newCount);
          }
          return this;
        }

        left = initLeft.setCount(comparator, e, expectedCount, newCount, result);

        // Adjust the cached totals only when the expectation held and the update was applied.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      } else if (cmp > 0) {
        AvlNode<E> initRight = right;
        if (initRight == null) {
          // e is absent, so its current count is 0; only an expectation of 0 can match.
          result[0] = 0;
          if (expectedCount == 0 && newCount > 0) {
            return addRightChild(e, newCount);
          }
          return this;
        }

        right = initRight.setCount(comparator, e, expectedCount, newCount, result);

        // Adjust the cached totals only when the expectation held and the update was applied.
        if (result[0] == expectedCount) {
          if (newCount == 0 && result[0] != 0) {
            this.distinctElements--;
          } else if (newCount > 0 && result[0] == 0) {
            this.distinctElements++;
          }
          this.totalCount += newCount - result[0];
        }
        return rebalance();
      }

      // setting my count
      result[0] = elemCount;
      if (expectedCount == elemCount) {
        if (newCount == 0) {
          return deleteMe();
        }
        this.totalCount += newCount - elemCount;
        this.elemCount = newCount;
      }
      return this;
    }
775 
    /**
     * Removes this node from the tree and from the linked list of nodes, and returns the node
     * that takes its place as the root of this subtree ({@code null} if it had no children).
     * The node's {@code elemCount} is set to 0.
     */
    private AvlNode<E> deleteMe() {
      // Remember the count: it is zeroed below but still needed to fix the replacement's total.
      int oldElemCount = this.elemCount;
      this.elemCount = 0;
      // Link the neighbours to each other. This node's own pred/succ fields are left untouched,
      // so they can still be read below.
      successor(pred, succ);
      if (left == null) {
        return right;
      } else if (right == null) {
        return left;
      } else if (left.height >= right.height) {
        // Two children, left at least as tall: the predecessor replaces this node.
        AvlNode<E> newTop = pred;
        // newTop is the maximum node in my left subtree
        newTop.left = left.removeMax(newTop);
        newTop.right = right;
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      } else {
        // Two children, right taller: the successor replaces this node.
        AvlNode<E> newTop = succ;
        newTop.right = right.removeMin(newTop);
        newTop.left = left;
        newTop.distinctElements = distinctElements - 1;
        newTop.totalCount = totalCount - oldElemCount;
        return newTop.rebalance();
      }
    }
801 
802     // Removes the minimum node from this subtree to be reused elsewhere
803     private AvlNode<E> removeMin(AvlNode<E> node) {
804       if (left == null) {
805         return right;
806       } else {
807         left = left.removeMin(node);
808         distinctElements--;
809         totalCount -= node.elemCount;
810         return rebalance();
811       }
812     }
813 
814     // Removes the maximum node from this subtree to be reused elsewhere
815     private AvlNode<E> removeMax(AvlNode<E> node) {
816       if (right == null) {
817         return left;
818       } else {
819         right = right.removeMax(node);
820         distinctElements--;
821         totalCount -= node.elemCount;
822         return rebalance();
823       }
824     }
825 
826     private void recomputeMultiset() {
827       this.distinctElements = 1 + TreeMultiset.distinctElements(left)
828           + TreeMultiset.distinctElements(right);
829       this.totalCount = elemCount + totalCount(left) + totalCount(right);
830     }
831 
832     private void recomputeHeight() {
833       this.height = 1 + Math.max(height(left), height(right));
834     }
835 
836     private void recompute() {
837       recomputeMultiset();
838       recomputeHeight();
839     }
840 
841     private AvlNode<E> rebalance() {
842       switch (balanceFactor()) {
843         case -2:
844           if (right.balanceFactor() > 0) {
845             right = right.rotateRight();
846           }
847           return rotateLeft();
848         case 2:
849           if (left.balanceFactor() < 0) {
850             left = left.rotateLeft();
851           }
852           return rotateRight();
853         default:
854           recomputeHeight();
855           return this;
856       }
857     }
858 
859     private int balanceFactor() {
860       return height(left) - height(right);
861     }
862 
863     private AvlNode<E> rotateLeft() {
864       checkState(right != null);
865       AvlNode<E> newTop = right;
866       this.right = newTop.left;
867       newTop.left = this;
868       newTop.totalCount = this.totalCount;
869       newTop.distinctElements = this.distinctElements;
870       this.recompute();
871       newTop.recomputeHeight();
872       return newTop;
873     }
874 
875     private AvlNode<E> rotateRight() {
876       checkState(left != null);
877       AvlNode<E> newTop = left;
878       this.left = newTop.right;
879       newTop.right = this;
880       newTop.totalCount = this.totalCount;
881       newTop.distinctElements = this.distinctElements;
882       this.recompute();
883       newTop.recomputeHeight();
884       return newTop;
885     }
886 
887     private static long totalCount(@Nullable AvlNode<?> node) {
888       return (node == null) ? 0 : node.totalCount;
889     }
890 
891     private static int height(@Nullable AvlNode<?> node) {
892       return (node == null) ? 0 : node.height;
893     }
894 
895     @Nullable private AvlNode<E> ceiling(Comparator<? super E> comparator, E e) {
896       int cmp = comparator.compare(e, elem);
897       if (cmp < 0) {
898         return (left == null) ? this : MoreObjects.firstNonNull(left.ceiling(comparator, e), this);
899       } else if (cmp == 0) {
900         return this;
901       } else {
902         return (right == null) ? null : right.ceiling(comparator, e);
903       }
904     }
905 
906     @Nullable private AvlNode<E> floor(Comparator<? super E> comparator, E e) {
907       int cmp = comparator.compare(e, elem);
908       if (cmp > 0) {
909         return (right == null) ? this : MoreObjects.firstNonNull(right.floor(comparator, e), this);
910       } else if (cmp == 0) {
911         return this;
912       } else {
913         return (left == null) ? null : left.floor(comparator, e);
914       }
915     }
916 
917     @Override
918     public E getElement() {
919       return elem;
920     }
921 
922     @Override
923     public int getCount() {
924       return elemCount;
925     }
926 
927     @Override
928     public String toString() {
929       return Multisets.immutableEntry(getElement(), getCount()).toString();
930     }
931   }
932 
933   private static <T> void successor(AvlNode<T> a, AvlNode<T> b) {
934     a.succ = b;
935     b.pred = a;
936   }
937 
938   private static <T> void successor(AvlNode<T> a, AvlNode<T> b, AvlNode<T> c) {
939     successor(a, b);
940     successor(b, c);
941   }
942 
943   /*
944    * TODO(jlevy): Decide whether entrySet() should return entries with an equals() method that
945    * calls the comparator to compare the two keys. If that change is made,
946    * AbstractMultiset.equals() can simply check whether two multisets have equal entry sets.
947    */
948 }
949